Semantic text - CCS POC - DO NOT MERGE #132411
Conversation
```java
SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(INFERENCE_FIELD, "foo");
queryBuilder.setModelRegistrySupplier(() -> modelRegistry);
```
This is admittedly a hacky way to pass the model registry to the semantic query, but I was looking for a way to do it that didn't involve a lot of refactoring. The proper way to do this is likely through the constructor.
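A minimal sketch of the trade-off discussed here, using illustrative names rather than the real Elasticsearch classes: a settable supplier keeps existing constructors untouched, at the cost of a dependency that can be missing until it is first resolved.

```java
import java.util.function.Supplier;

// Illustrative sketch only; these are not the actual Elasticsearch classes.
class RegistrySketch {
    static class ModelRegistry {}

    static class QuerySketch {
        private Supplier<ModelRegistry> registrySupplier = () -> null;

        // Supplier-based injection: avoids refactoring every call site, but the
        // dependency can silently be absent until rewrite time.
        void setModelRegistrySupplier(Supplier<ModelRegistry> supplier) {
            this.registrySupplier = supplier;
        }

        ModelRegistry resolveRegistry() {
            ModelRegistry registry = registrySupplier.get();
            if (registry == null) {
                // Mirrors the "has not been set" failure mode discussed below.
                throw new IllegalStateException("Model registry has not been set");
            }
            return registry;
        }
    }
}
```

Constructor injection would make the missing-dependency state unrepresentable, which is why the comment calls it the proper fix.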
```java
String inferenceId = getInferenceIdForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
SetOnce<InferenceServiceResults> inferenceResultsSupplier = new SetOnce<>();
boolean noInferenceResults = false;
if (inferenceId != null) {
    InferenceAction.Request inferenceRequest = new InferenceAction.Request(
        TaskType.ANY,
        inferenceId,
        null,
        null,
        null,
        List.of(query),
        Map.of(),
        InputType.INTERNAL_SEARCH,
        null,
        false
    );

    queryRewriteContext.registerAsyncAction(
        (client, listener) -> executeAsyncWithOrigin(
            client,
            ML_ORIGIN,
            InferenceAction.INSTANCE,
            inferenceRequest,
            listener.delegateFailureAndWrap((l, inferenceResponse) -> {
                inferenceResultsSupplier.set(inferenceResponse.getResults());
                l.onResponse(null);
            })
        )
    );
    MapEmbeddingsProvider currentEmbeddingsProvider;
    if (embeddingsProvider != null) {
        if (embeddingsProvider instanceof MapEmbeddingsProvider mapEmbeddingsProvider) {
            currentEmbeddingsProvider = mapEmbeddingsProvider;
        } else {
            throw new IllegalStateException("Current embeddings provider should be a MapEmbeddingsProvider");
        }
    } else {
        // The inference ID can be null if either the field name or index name(s) are invalid (or both).
        // If this happens, we set the "no inference results" flag to true so the rewrite process can continue.
        // Invalid index names will be handled in the transport layer, when the query is sent to the shard.
        // Invalid field names will be handled when the query is re-written on the shard, where we have access to the index mappings.
        noInferenceResults = true;
        currentEmbeddingsProvider = new MapEmbeddingsProvider();
    }

    return new SemanticQueryBuilder(this, noInferenceResults ? null : inferenceResultsSupplier, null, noInferenceResults);
    boolean modified = false;
    if (queryRewriteContext.hasAsyncActions() == false) {
        ModelRegistry modelRegistry = modelRegistrySupplier.get();
        if (modelRegistry == null) {
            throw new IllegalStateException("Model registry has not been set");
        }

        Set<String> inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
        for (String inferenceId : inferenceIds) {
            MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId);
            InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings);

            if (currentEmbeddingsProvider.getEmbeddings(inferenceEndpointKey) == null) {
                InferenceAction.Request inferenceRequest = new InferenceAction.Request(
                    TaskType.ANY,
                    inferenceId,
                    null,
                    null,
                    null,
                    List.of(query),
                    Map.of(),
                    InputType.INTERNAL_SEARCH,
                    null,
                    false
                );

                queryRewriteContext.registerAsyncAction(
                    (client, listener) -> executeAsyncWithOrigin(
                        client,
                        ML_ORIGIN,
                        InferenceAction.INSTANCE,
                        inferenceRequest,
                        listener.delegateFailureAndWrap((l, inferenceResponse) -> {
                            currentEmbeddingsProvider.addEmbeddings(
                                inferenceEndpointKey,
                                validateAndConvertInferenceResults(inferenceResponse.getResults(), fieldName, inferenceId)
                            );
                            l.onResponse(null);
                        })
                    )
                );

                modified = true;
            }
        }
    }

    return modified ? new SemanticQueryBuilder(this, currentEmbeddingsProvider, false) : this;
}
```
This logic demonstrates a way to reuse embeddings cross-cluster when they are compatible. For the sake of this POC, I chose to use the combination of inference ID + minimal service settings to qualify inference endpoints as equal.
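As a sketch of the equivalence scheme described above (all names here are hypothetical, not the actual `InferenceEndpointKey` or `MinimalServiceSettings` classes), a Java record gives value-based `equals`/`hashCode` for free, so two endpoints with the same inference ID and the same minimal service settings collapse to a single cache entry:

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative sketch only; field choices stand in for whatever the minimal
// service settings actually contain.
class EndpointKeySketch {
    // Records get value-based equals()/hashCode(), which is exactly what a
    // map key needs.
    record ServiceSettings(String service, String modelId, int dimensions) {}

    record EndpointKey(String inferenceId, ServiceSettings settings) {}

    // A map-backed embeddings cache keyed by endpoint equivalence: compatible
    // endpoints across clusters hit the same entry and reuse the embeddings.
    static final Map<EndpointKey, float[]> embeddingsByEndpoint = new HashMap<>();
}
```

With this scheme, a remote cluster whose endpoint matches on both components reuses the locally computed embeddings instead of re-running inference.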
I wonder if we could make this even simpler. For example, apply a warning if just the inference IDs are different; it is, after all, just a warning that some potential difference was detected. This may not be as precise as consulting the model registry to detect different models. We could also consider a flag for a lenient mode to suppress warnings if people intentionally want to use different inference IDs.
Understood how there could be a gap here in detecting compatible embeddings. If we want to be more conservative, we could use something like cluster name + inference ID to identify embeddings in the map. That would mean no embedding reuse cross-cluster, though.
Setting a warning doesn't work for CCS, though, as warning headers are not transmitted back to the primary cluster.
Nice POC! I've left some high-level comments on functionality, as this is still a POC.
```java
Rewriteable.rewriteAndFetch(
    original,
    searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, null),
    searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, null, original.isCcsMinimizeRoundtrips()),
```
Is there precedent for having CCS-specific knobs like this in generic search code?
Not sure, but there's a good case to be made for why this is necessary. The CCS mode affects the query rewrite cycle, so we need a way to know about it within that context. Information is passed to the query rewrite via QueryRewriteContext, hence this implementation.
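The flow described here could be sketched as follows, with illustrative names rather than the real `QueryRewriteContext` API: the CCS flag rides on the rewrite context because that is the only state visible during the rewrite cycle.

```java
// Illustrative sketch only; not the actual Elasticsearch classes.
class RewriteContextSketch {
    static class RewriteContext {
        private final boolean ccsMinimizeRoundtrips;

        RewriteContext(boolean ccsMinimizeRoundtrips) {
            this.ccsMinimizeRoundtrips = ccsMinimizeRoundtrips;
        }

        boolean isCcsMinimizeRoundtrips() {
            return ccsMinimizeRoundtrips;
        }
    }

    // A query rewrite can branch on the CCS mode without having access to the
    // original search request, because the flag was threaded into the context.
    static boolean shouldInferOnCoordinator(RewriteContext context) {
        return context.isCcsMinimizeRoundtrips();
    }
}
```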
```java
import java.io.IOException;
import java.util.Objects;

public class InferenceEndpointKey implements Writeable {
```
I like this conceptually, but nitpicky - I would like to find a better name for it.
Sure, I wasn't focusing too much on names, just proving out functionality.
```java
ModelRegistry modelRegistry = modelRegistrySupplier.get();
if (modelRegistry == null) {
    throw new IllegalStateException("Model registry has not been set");
```
We may need to test this to make sure it actually should always return a 500 / trigger a serverless alert, similar to some other alerts we've been seeing for semantic queries.
This is one of those "should never happen in production" errors. If it does, it's a symptom of an upstream problem that we should be alerted to.
```java
);
MapEmbeddingsProvider currentEmbeddingsProvider;
if (embeddingsProvider != null) {
    if (embeddingsProvider instanceof MapEmbeddingsProvider mapEmbeddingsProvider) {
```
Not sure if this check should be necessary.
This is another one of those "this should never fail in production" cases. If we get here, it means we're performing a coordinator node rewrite, which means we're performing inference on either a local or remote cluster.
If we're on a local cluster, we can assume that this node built the initial query, and thus the embeddings provider should be a MapEmbeddingsProvider.
If we're on a remote cluster, we can assume that the primary (i.e. local) cluster allows semantic queries to perform CCS, which is directly correlated with usage of MapEmbeddingsProvider.
Either way, we need the representation to be a MapEmbeddingsProvider so that we can call addEmbeddings later, hence this check.
```java
    )
);
} else if (inferenceResultsList.size() > 1) {
    // The inference call should truncate if the query is too large.
```
Not sure if that's the case for all models? For example, we warn in our docs that OpenAI will error if BYO chunks are too large.
Remember that we are handling query-time inference here, which will never chunk. We should always get back one inference result; if we get back more, something has gone horribly wrong in the Inference API that we want to know about, hence this check.
This comment may not be fully technically correct, in that some providers may error instead of truncating on huge input. However, in that case we will still get back only one inference result; it will just be an instance of ErrorInferenceResults.
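The invariant described above could be sketched like this (a hypothetical helper, not the actual validation code): query-time inference never chunks, so exactly one result is expected back, and anything else is surfaced as an error rather than silently ignored.

```java
import java.util.List;

// Illustrative sketch only; not the actual inference API types.
class SingleResultSketch {
    // Enforces the "exactly one query-time inference result" invariant.
    static <T> T requireSingleResult(List<T> results) {
        if (results.isEmpty()) {
            throw new IllegalStateException("Expected exactly one inference result, got none");
        } else if (results.size() > 1) {
            // More than one result would mean the upstream inference call
            // chunked the input, which should never happen at query time.
            throw new IllegalStateException("Expected exactly one inference result, got " + results.size());
        }
        return results.get(0);
    }
}
```

Note that a provider-side error would still arrive as a single (error) result, so this check would not fire in that case.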
Superseded by #133466
This is a POC implementation of CCS support for the `semantic` query when `ccs_minimize_roundtrips=true`. It implements:

- `semantic` query multi-index handling (adapted from Support using the semantic query across multiple inference IDs #120755)
- `semantic` query CCS support when `ccs_minimize_roundtrips=true`
- `ccs_minimize_roundtrips=false`